require(mice)
## Warning: package 'mice' was built under R version 4.1.3
require(lattice)
library(tidyverse)
library(rMIDAS)
# set_python_env(python ="/opt/anaconda3/bin/python")
set_python_env(x ="C:\\ProgramData\\Anaconda3\\",type = "conda")
## [1] TRUE
library(ggplot2)
library(gridExtra)
library("GGally")
library(gdata)
# Read adult_data.csv from the MIDASpy examples, taking the first CSV column as
# row names, and keep only the first 3000 rows.
data <- read.csv(
  file = "https://raw.githubusercontent.com/MIDASverse/MIDASpy/master/Examples/adult_data.csv",
  row.names = 1
)[seq_len(3000), ]
# Preview the first rows of the loaded data.
head(data)
## age workclass fnlwgt education education_num marital_status
## 0 39 State-gov 77516 Bachelors 13 Never-married
## 1 50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse
## 2 38 Private 215646 HS-grad 9 Divorced
## 3 53 Private 234721 11th 7 Married-civ-spouse
## 4 28 Private 338409 Bachelors 13 Married-civ-spouse
## 5 37 Private 284582 Masters 14 Married-civ-spouse
## occupation relationship race sex capital_gain capital_loss
## 0 Adm-clerical Not-in-family White Male 2174 0
## 1 Exec-managerial Husband White Male 0 0
## 2 Handlers-cleaners Not-in-family White Male 0 0
## 3 Handlers-cleaners Husband Black Male 0 0
## 4 Prof-specialty Wife Black Female 0 0
## 5 Exec-managerial Wife White Female 0 0
## hours_per_week native_country class_labels
## 0 40 United-States <=50K
## 1 13 United-States <=50K
## 2 40 United-States <=50K
## 3 40 United-States <=50K
## 4 40 Cuba <=50K
## 5 40 United-States <=50K
# Column groups of the adult data: multi-level categorical, binary and numeric.
adult_cat <- c(
  "workclass", "marital_status", "relationship", "race",
  "education", "occupation", "native_country"
)
adult_bin <- c("sex", "class_labels")
adult_num <- c(
  "age", "fnlwgt", "education_num",
  "capital_gain", "capital_loss", "hours_per_week"
)
# Store every binary and categorical column as a factor.
for (col in c(adult_bin, adult_cat)) {
  data[[col]] <- as.factor(data[[col]])
}
# qplot(data$workclass)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$marital_status)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$relationship)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$race)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$education)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$occupation)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$native_country)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
#### Create Missing Data
# Inject missing values with add_missingness(prop = 0.2) and make sure the
# result is a plain data.frame.
miss_data <- as.data.frame(add_missingness(data, prop = 0.2))
# Show the number of missing values in each column.
print(vapply(miss_data, function(column) sum(is.na(column)), integer(1)))
## age workclass fnlwgt education education_num
## 626 592 571 627 593
## marital_status occupation relationship race sex
## 639 580 601 606 611
## capital_gain capital_loss hours_per_week native_country class_labels
## 597 578 595 578 591
#### Imputing Data with missRanger
library(missRanger)
# Ten independent missRanger runs (100 trees each), kept as a list of
# data frames.
impt_ranger_data <- replicate(
  10,
  as.data.frame(missRanger(miss_data, verbose = 0, num.trees = 100)),
  simplify = FALSE
)
# Round imputed binary/categorical columns to whole codes, but only when they
# are numeric: round() raises an error on factor or character columns.
# NOTE(review): miss_data is built from factor columns, so missRanger presumably
# returns factors here and this step is then a no-op -- confirm.
for (i in seq_along(impt_ranger_data)) {
  for (cat_col in c(adult_bin, adult_cat)) {
    imputed_col <- impt_ranger_data[[i]][[cat_col]]
    if (is.numeric(imputed_col)) {
      impt_ranger_data[[i]][[cat_col]] <- round(imputed_col)
    }
  }
}
# Dry run (maxit = 0): only the default per-column method vector is needed
# here, so no imputation iterations are performed.
imp <- mice(miss_data, maxit = 0, printFlag = FALSE)
meth <- imp$method
# Override the defaults: CART for multi-level categoricals, random forest for
# the binary and numeric columns.
meth[adult_cat] <- "cart"
meth[adult_bin] <- "rf"
meth[adult_num] <- "rf"
# Ten imputations with the customised methods (maxit left at mice's default) ...
imp <- mice(miss_data, m = 10, method = meth, printFlag = FALSE)
# ... continued for 15 further iterations.
imp20 <- mice.mids(imp, maxit = 15, printFlag = FALSE)
# Extract every completed data set into a plain list of data frames.
impt_mice_data <- lapply(
  seq_len(imp20$m),
  function(i) mice::complete(imp20, action = i)
)
# Prepare the incomplete data for rMIDAS: declare the binary and categorical
# columns and request min-max scaling.
data_conv <- rMIDAS::convert(
  miss_data,
  minmax_scale = TRUE,
  cat_cols = adult_cat,
  bin_cols = adult_bin
)
# Fit the model: 20 training epochs, layer_structure c(128, 128), fixed seed.
rmidas_train <- rMIDAS::train(
  data_conv,
  seed = 89,
  input_drop = 0.75,
  layer_structure = c(128, 128),
  training_epochs = 20
)
# Draw 10 completed data sets from the trained model.
impt_rmidas_data <- rMIDAS::complete(rmidas_train, fast = TRUE, m = 10)
# Build a long data frame that stacks imputed values on top of the true values
# for the rows in which `col` is missing in `miss_df`.
#
# Arguments:
#   df            complete (ground-truth) data frame
#   miss_df       the same data with missing values; coerced via as.data.frame()
#   impt_df_list  list of at least `m` imputed data frames, row-aligned with `df`
#   col           name of the column whose missing rows are compared
#   m             number of imputed data sets taken from `impt_df_list`
#   method        label written to `source` for imputed rows
#   sp_impt       "method": imputed rows are labelled "<method>-<i>";
#                 "sex": every `source` is prefixed with the row's `sex` value;
#                 any other value leaves `source` untouched
#
# Returns a data frame holding the imputed rows (last imputation first) followed
# by the true rows, with two extra columns: `source` ("True" or the method
# label) and `na_count` (missing values in that row of `miss_df`, or
# "True(0 na)" for the true rows).
create_compare_data <- function(df, miss_df, impt_df_list, col, m = 10,
                                method = "mice", sp_impt = "sex") {
  miss_df <- as.data.frame(miss_df)
  if (m > length(impt_df_list)) {
    stop("`m` exceeds the number of imputed data sets supplied", call. = FALSE)
  }

  # Only the rows where `col` was made missing are compared.
  miss_index <- which(is.na(miss_df[[col]]))
  n_miss <- length(miss_index)
  # rowSums() on the logical NA mask avoids apply() coercing the data frame
  # to a character matrix.
  na_count <- rowSums(is.na(miss_df[miss_index, , drop = FALSE]))

  truth <- df[miss_index, , drop = FALSE]
  truth[["source"]] <- rep("True", n_miss)
  truth$na_count <- rep("True(0 na)", n_miss)

  # seq_len() keeps m = 0 safe (1:0 would iterate over c(1, 0)).
  imputed <- lapply(seq_len(m), function(i) {
    part <- as.data.frame(impt_df_list[[i]])[miss_index, , drop = FALSE]
    label <- if (sp_impt == "method") paste(method, i, sep = "-") else method
    part[["source"]] <- rep(label, n_miss)
    part$na_count <- na_count
    part
  })

  # Single rbind; order matches repeated prepending: last imputation first,
  # true rows last.
  combined <- do.call(rbind, c(rev(imputed), list(truth)))

  if (sp_impt == "sex") {
    combined$source <- paste(combined$sex, combined$source, sep = "-")
  }
  combined
}
library(scales)
library(caret)
library(gdata)
# Draw a confusion matrix as a ggplot heat map.
#
# Arguments:
#   m            object providing `$overall` (positions 1 and 2 are shown in the
#                title as Accuracy and Kappa) and `$table`
#   col_names    tick labels used on both axes
#   method_name  text placed at the start of the plot title
#
# Returns the ggplot object; each tile is labelled with its share of the
# corresponding Reference group.
ggplotConfusionMatrix <- function(m, col_names, method_name) {
  # https://stackoverflow.com/questions/51410405/ggplot2-confusion-matrix-geom-text-labeling
  as_percent <- percent_format()
  plot_title <- paste(
    method_name,
    "Accuracy", as_percent(m$overall[1]),
    "Kappa", as_percent(m$overall[2])
  )

  # Frequencies per cell plus the percentage within each Reference group.
  cell_data <- as.data.frame(m$table) %>%
    group_by(Reference) %>%
    mutate(percentage = percent(Freq / sum(Freq)))

  tile_plot <- ggplot(data = cell_data, aes(x = Reference, y = Prediction)) +
    geom_tile(aes(fill = Freq), colour = "white") +
    scale_fill_gradient(low = "white", high = "green") +
    geom_text(aes(x = Reference, y = Prediction, label = percentage)) +
    scale_x_discrete(labels = col_names, guide = guide_axis(angle = 45)) +
    scale_y_discrete(labels = col_names) +
    ggtitle(plot_title)

  tile_plot
}
# Plot a confusion matrix of imputed vs. true categories for the rows in which
# `col` is missing in `miss_df`, pooled over `m` imputed data sets.
#
# Arguments:
#   impt_data_list  list of at least `m` imputed data frames, row-aligned
#                   with `data`
#   data            complete (ground-truth) data frame
#   miss_df         the same data with missing values
#   col             name of the categorical column to evaluate
#   method_name     label shown in the plot title
#   m               number of imputed data sets to pool
#
# Returns the ggplot object produced by ggplotConfusionMatrix().
plot_confusion_matrix <- function(impt_data_list, data, miss_df, col, method_name, m = 10) {
  miss_df <- as.data.frame(miss_df)
  miss_index <- which(is.na(miss_df[[col]]))
  if (length(miss_index) == 0) {
    stop("no missing values found in column '", col, "'", call. = FALSE)
  }

  # Levels come from the full true column, so the matrix always has one
  # row/column per category and the axis labels line up with the levels even
  # when a category does not occur among the held-out rows.
  truth <- as.factor(data[[col]])
  level_names <- levels(truth)
  level_codes <- seq_along(level_names)

  # Map imputed values to level codes: labels (factor/character) are matched by
  # name; numeric values are taken to be codes already.
  to_codes <- function(values) {
    if (is.numeric(values)) {
      return(values)
    }
    match(as.character(values), level_names)
  }

  pred_codes <- unlist(lapply(seq_len(m), function(i) {
    to_codes(impt_data_list[[i]][[col]][miss_index])
  }))
  true_codes <- rep(as.integer(truth)[miss_index], times = m)

  pred_values <- factor(pred_codes, levels = level_codes)
  true_labels <- factor(true_codes, levels = level_codes)

  # caret expects the predictions as `data` and the truth as `reference`.
  cfm <- confusionMatrix(data = pred_values, reference = true_labels)
  ggplotConfusionMatrix(cfm, level_names, method_name)
}
# For each categorical column: plot the column from the complete data, then one
# confusion matrix per imputation method (ranger, mice, rmidas).
# The "## Warning" lines are output recorded from an earlier run.
# marital_status
qplot(data$marital_status)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
plot_confusion_matrix(impt_ranger_data,data,miss_data,col = "marital_status",method_name="ranger")
plot_confusion_matrix(impt_mice_data,data,miss_data,col ="marital_status",method_name="mice")
plot_confusion_matrix(impt_rmidas_data,data,miss_data,col ="marital_status",method_name="rmidas")
# workclass
qplot(data$workclass)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
plot_confusion_matrix(impt_ranger_data,data,miss_data,col = "workclass",method_name="ranger")
## Warning: Removed 14 rows containing missing values (geom_text).
plot_confusion_matrix(impt_mice_data,data,miss_data,col ="workclass",method_name="mice")
plot_confusion_matrix(impt_rmidas_data,data,miss_data,col ="workclass",method_name="rmidas")
# relationship
qplot(data$relationship)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
plot_confusion_matrix(impt_ranger_data,data,miss_data,col = "relationship",method_name="ranger")
## Warning: Removed 6 rows containing missing values (geom_text).
plot_confusion_matrix(impt_mice_data,data,miss_data,col ="relationship",method_name="mice")
plot_confusion_matrix(impt_rmidas_data,data,miss_data,col ="relationship",method_name="rmidas")
## race
qplot(data$race)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
plot_confusion_matrix(impt_ranger_data,data,miss_data,col = "race",method_name="ranger")
## Warning: Removed 10 rows containing missing values (geom_text).
plot_confusion_matrix(impt_mice_data,data,miss_data,col ="race",method_name="mice")
plot_confusion_matrix(impt_rmidas_data,data,miss_data,col ="race",method_name="rmidas")
## education
qplot(data$education)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
plot_confusion_matrix(impt_ranger_data,data,miss_data,col = "education",method_name="ranger")
plot_confusion_matrix(impt_mice_data,data,miss_data,col ="education",method_name="mice")
plot_confusion_matrix(impt_rmidas_data,data,miss_data,col ="education",method_name="rmidas")
## Warning: Removed 32 rows containing missing values (geom_text).
## occupation
qplot(data$occupation)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
plot_confusion_matrix(impt_ranger_data,data,miss_data,col = "occupation",method_name="ranger")
## Warning: Removed 42 rows containing missing values (geom_text).
plot_confusion_matrix(impt_mice_data,data,miss_data,col ="occupation",method_name="mice")
plot_confusion_matrix(impt_rmidas_data,data,miss_data,col ="occupation",method_name="rmidas")
## native_country
qplot(data$native_country)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
plot_confusion_matrix(impt_ranger_data,data,miss_data,col = "native_country",method_name="ranger")
## Warning: Removed 338 rows containing missing values (geom_text).
plot_confusion_matrix(impt_mice_data,data,miss_data,col ="native_country",method_name="mice")
## Warning: Removed 104 rows containing missing values (geom_text).
plot_confusion_matrix(impt_rmidas_data,data,miss_data,col ="native_country",method_name="rmidas")
## Warning: Removed 390 rows containing missing values (geom_text).
# True vs. mice-imputed hours_per_week against age for the rows where
# hours_per_week was made missing; `source` separates the individual
# imputations.
df_mice_wgt <- create_compare_data(
  data, miss_data, impt_mice_data,
  col = "hours_per_week", method = "mice", sp_impt = "method"
)
ggplot(df_mice_wgt, aes(x = age, y = hours_per_week, colour = source)) +
  geom_point(alpha = 0.4) +
  stat_smooth()
# Same comparison, with `source` combining sex and the true/mice label.
df_mice_wgt <- create_compare_data(
  data, miss_data, impt_mice_data,
  col = "hours_per_week", method = "mice", sp_impt = "sex"
)
ggplot(df_mice_wgt, aes(x = age, y = hours_per_week, colour = source)) +
  geom_point(alpha = 0.4) +
  stat_smooth()
# Compare the true age with the age imputed by each method for the rows where
# age was made missing: one three-panel figure (mice, ranger, midas) per
# imputation.
miss_index <- which(is.na(miss_data$age))
# Loop-invariant colour grouping, computed once.
sex <- factor(data$sex[miss_index])
for (i in seq_len(10)) {
  # Plot the i-th imputed data set of every method. Indexing with a fixed 3
  # here would draw the same figure on every iteration.
  # NOTE(review): ylim(-5, 22) drops every imputed age outside that range, and
  # the ages in the data preview are well above 22 -- confirm the intended limits.
  g1 <- qplot(data$age[miss_index], impt_mice_data[[i]]$age[miss_index], col = sex) +
    stat_smooth() + ylim(-5, 22) +
    ylab("mice age") + xlab("data age") + theme(legend.position = "top")
  g2 <- qplot(data$age[miss_index], impt_ranger_data[[i]]$age[miss_index], col = sex) +
    stat_smooth() + ylim(-5, 22) +
    ylab("ranger age") + xlab("data age") + theme(legend.position = "top")
  g3 <- qplot(data$age[miss_index], impt_rmidas_data[[i]]$age[miss_index], col = sex) +
    stat_smooth() + ylim(-5, 22) +
    ylab("midas age") + xlab("data age") + theme(legend.position = "top")
  grid.arrange(g1, g2, g3, ncol = 3)
}